#!/usr/bin/env python3

# Copyright (c) the Selfie Project authors. All rights reserved.
# Please see the AUTHORS file for details. Use of this source code is
# governed by a BSD license that can be found in the LICENSE file.

# Selfie is a project of the Computational Systems Group at the
# Department of Computer Sciences of the University of Salzburg
# in Austria. For further information and code please refer to:

# selfie.cs.uni-salzburg.at

# Bitme is a bounded model checker for BTOR2 models using
# the Z3 and bitwuzla SMT solvers as reasoning engines.

# Bitme is designed to work with BTOR2 models generated by rotor
# for modeling RISC-V machines and RISC-V code. Rotor is a tool
# that is part of the selfie system.

# ------------------------------------------------------------

# for debugging segfaults: import faulthandler; faulthandler.enable()

import btor2

import z3interface
import bitwuzlainterface

import bvdd as BVDD
import cflobvdd as CFLOBVDD

import ctypes

# try to load the rotor shared library and record whether that worked
try:
    rotor = ctypes.cdll.LoadLibrary("rotor")
except OSError:
    print("rotor is not available")
    is_rotor_present = False
else:
    is_rotor_present = True

import concurrent.futures

class Futures:
    # runs get_step(step) on a thread pool shared by all instances;
    # get_step is not defined here and must be provided by the inheriting class
    executor = concurrent.futures.ThreadPoolExecutor()

    def __init__(self):
        # step -> future of the get_step(step) call submitted by fork_step
        self.cache_futures = {}

    def wait_step(self, step):
        # wait for a previously forked computation of step, if there is one,
        # and then return the result of get_step(step)
        assert step >= 0
        pending = self.cache_futures.get(step)
        if pending is not None:
            concurrent.futures.wait([pending])
        return self.get_step(step)

    def fork_step(self, step):
        # submit get_step(step) to the thread pool unless already submitted
        assert step >= 0
        if step in self.cache_futures:
            return
        self.cache_futures[step] = Futures.executor.submit(self.get_step, step)

from math import log2
from math import ceil

class Values:
    # switches selecting the decision-diagram backends
    BVDD = False
    BVDD_level = 0
    BVDD_number_of_inputs = 0
    # input index -> variable line
    BVDD_input = {}
    # variable line -> input index
    BVDD_index = {}

    CFLOBVDD = False
    CFLOBVDD_level = 0
    CFLOBVDD_swap_level = 0
    CFLOBVDD_fork_level = 0
    # the constructor reads CFLOBVDD_number_of_inputs, which previously had
    # no default because of the misspelled name below; the misspelled name
    # is kept so that existing references to it keep working
    CFLOBVDD_number_of_inputs = 0
    CLOBVDD_number_of_inputs = 0
    # input index -> variable line
    CFLOBVDD_input = {}
    # variable line -> input index
    CFLOBVDD_index = {}

    # statistics accumulated by the constructor
    total_number_of_constants = 0
    total_number_of_values = 0

    total_number_of_distinct_inputs = 0
    max_number_of_connections = 0

    # lazily created singletons, see FALSE() and TRUE()
    false = None
    true = None

    def __init__(self, sid_line, value, var_line = None, bvdd = None, cflobvdd = None):
        # wraps the decision diagrams describing the values of a line of sort sid_line:
        # - if value is a bool or int, constant diagrams are built for it
        # - else if var_line is a Variable, projection diagrams are built for it
        # - else the diagrams passed in as bvdd and cflobvdd are kept as is
        # class-wide statistics are updated along the way
        assert isinstance(sid_line, Bitvector)
        self.sid_line = sid_line
        self.bvdd = bvdd
        self.cflobvdd = cflobvdd

        if isinstance(value, (bool, int)):
            assert sid_line.is_unsigned_value(value)

            if Values.BVDD:
                self.bvdd = BVDD.BVDD.constant(value)
            if Values.CFLOBVDD:
                self.cflobvdd = CFLOBVDD.CFLOBVDD.byte_constant(
                    Values.CFLOBVDD_level,
                    Values.CFLOBVDD_swap_level,
                    Values.CFLOBVDD_fork_level,
                    Values.CFLOBVDD_number_of_inputs,
                    value)

            Values.total_number_of_constants += 1
        elif isinstance(var_line, Variable):
            if Values.BVDD:
                self.bvdd = BVDD.BVDD.projection(Values.BVDD_index[var_line])
            if Values.CFLOBVDD:
                self.cflobvdd = CFLOBVDD.CFLOBVDD.byte_projection(
                    Values.CFLOBVDD_level,
                    Values.CFLOBVDD_swap_level,
                    Values.CFLOBVDD_fork_level,
                    Values.CFLOBVDD_number_of_inputs,
                    Values.CFLOBVDD_index[var_line],
                    True)

            Values.total_number_of_constants += 2**var_line.sid_line.size

            # statistics are taken from the BVDD if enabled, else from the CFLOBVDD
            if Values.BVDD:
                counted = self.bvdd
            else:
                assert Values.CFLOBVDD
                counted = self.cflobvdd
            Values.total_number_of_values += counted.number_of_outputs()
            Values.total_number_of_distinct_inputs += counted.number_of_distinct_inputs()

        if Values.BVDD:
            connections = self.bvdd.number_of_connections()
        else:
            assert Values.CFLOBVDD
            connections = self.cflobvdd.number_of_connections()
        Values.max_number_of_connections = max(Values.max_number_of_connections, connections)

        # for debugging assert self.is_consistent()

    def __str__(self):
        return f"{self.sid_line}: {self.bvdd} {self.cflobvdd}"

    def match_sorts(self, values):
        return self.sid_line.match_sorts(values.sid_line)

    def is_consistent(self):
        return self.bvdd.is_consistent()

    # BVDD adapter

    def get_input_expression(var_line, inputs):
        # Builds a boolean BTOR2 expression that is true if and only if the
        # value of var_line is a member of the set encoded by the bitmask
        # inputs, where bit i of inputs being set means that value i is in
        # the set.
        #
        # Returns [] if inputs is 0 (empty set), otherwise a list with the
        # resulting Comparison line as its only element.
        #
        # There is no self parameter: callers invoke this through the class
        # as Values.get_input_expression(...).
        #
        # NOTE(review): btor2.Parser.next_nid() presumably hands out fresh
        # node ids; if so, the nids of the created lines depend on the order
        # in which the nested constructor arguments below are evaluated.
        if inputs == 0:
            return []
        else:
            assert inputs > 0

            # sort wide enough to hold one bit per possible value of var_line
            inputs_sid_line = Bitvec(btor2.Parser.next_nid(), 2**var_line.sid_line.size,
                var_line.comment, var_line.line_no)
            inputs_zero_line = Constd(btor2.Parser.next_nid(), inputs_sid_line, 0,
                var_line.comment, var_line.line_no)
            inputs_one_line = Constd(btor2.Parser.next_nid(), inputs_sid_line, 1,
                var_line.comment, var_line.line_no)

            if inputs.bit_count() == 1:
                # true if value of var_line is in singleton-set inputs
                # (the position of the only set bit is the value to compare with)
                comparison_line = Comparison(btor2.Parser.next_nid(), btor2.OP_EQ, Bool.boolean,
                    Constd(btor2.Parser.next_nid(), var_line.sid_line,
                        int(log2(inputs)),
                        var_line.comment, var_line.line_no),
                    var_line,
                    var_line.comment, var_line.line_no)
            else:
                # true if value of var_line is in inputs
                comparison_line = Comparison(btor2.Parser.next_nid(), btor2.OP_NEQ, Bool.boolean,
                    # check if value of var_line is in inputs by masking inputs with 2**var_line
                    Logical(btor2.Parser.next_nid(), btor2.OP_AND, inputs_sid_line,
                        Constd(btor2.Parser.next_nid(), inputs_sid_line,
                            inputs,
                            var_line.comment, var_line.line_no),
                        # calculate 2**var_line by shifting 1 left by value of var_line
                        Computation(btor2.Parser.next_nid(), btor2.OP_SLL, inputs_sid_line,
                            inputs_one_line,
                            # zero-extend var_line to the width of the bitmask sort
                            Ext(btor2.Parser.next_nid(), btor2.OP_UEXT, inputs_sid_line, var_line,
                                2**var_line.sid_line.size - var_line.sid_line.size,
                                var_line.comment, var_line.line_no),
                            var_line.comment, var_line.line_no),
                        var_line.comment, var_line.line_no),
                    inputs_zero_line,
                    var_line.comment, var_line.line_no)
            return [comparison_line]

    def get_bvdd_node_expression(sid_line, bvdd, sbdd, index = 0):
        # Recursively translates the BVDD node bvdd into a BTOR2 expression
        # of sort sid_line; index selects the input variable of this node
        # via Values.BVDD_input and grows by one per level of recursion.
        #
        # - a bool or int is a terminal and becomes a Constd line
        # - a dont-care node is skipped by descending into its only output
        # - any other node becomes a chain of Ite lines over the entries
        #   of its inputs-to-output mapping
        #
        # If sbdd is true, each key of that mapping is treated as a single
        # input value in [0, 256) and converted into a one-hot bitmask
        # before the membership test is built.
        #
        # There is no self parameter: callers invoke this through the class.
        if isinstance(bvdd, bool) or isinstance(bvdd, int):
            return Constd(btor2.Parser.next_nid(), sid_line, int(bvdd),
                "domain-propagated value", 0)
        elif bvdd.is_dont_care():
            return Values.get_bvdd_node_expression(sid_line,
                bvdd.get_dont_care_output(), sbdd, index + 1)
        else:
            var_line = Values.BVDD_input[index]
            # stays None if the mapping turns out to be empty
            exp_line = None
            s2o = bvdd.get_s2o()
            # assert s2o is sorted by inputs
            for inputs in s2o:
                output = s2o[inputs]
                if exp_line is None:
                    # reachable only if input value is in inputs
                    # (first entry serves as innermost else-branch, no test needed)
                    exp_line = Values.get_bvdd_node_expression(sid_line, output, sbdd, index + 1)
                else:
                    if sbdd:
                        assert 0 <= inputs < 256
                        inputs = 2**inputs
                    # inputs is nonzero here, hence the list returned below is not empty
                    exp_line = Ite(btor2.Parser.next_nid(), sid_line,
                        Values.get_input_expression(var_line, inputs)[0],
                        Values.get_bvdd_node_expression(sid_line, output, sbdd, index + 1),
                        exp_line,
                        var_line.comment, var_line.line_no)
        return exp_line

    def get_bvdd_expression(self):
        # translate the BVDD of these values into a BTOR2 expression; the
        # keys of the diagram need conversion into bitmasks unless the
        # diagram is an SBDD_s2o or SBDD_o2s
        is_set_based = isinstance(self.bvdd, (BVDD.SBDD_s2o, BVDD.SBDD_o2s))
        return Values.get_bvdd_node_expression(self.sid_line, self.bvdd, not is_set_based)

    # CFLOBVDD adapter

    def get_logical_expression(op, paths):
        # fold the boolean lines in paths from left to right with the
        # logical operator op; returns [] for no paths and otherwise
        # a list with the combined line as its only element
        if not paths:
            return []
        combined = None
        for path_line in paths:
            combined = (path_line if combined is None else
                Logical(btor2.Parser.next_nid(), op, Bool.boolean,
                    combined,
                    path_line,
                    path_line.comment, path_line.line_no))
        return [combined]

    def get_path_expression(paths):
        # translate a collection of paths into a disjunction: a path is either
        # a pair (input index, inputs bitmask) of integers, which becomes
        # a membership test, or a pair of nested path collections, which
        # becomes the conjunction of both translated collections
        alternatives = []
        for path in paths:
            if not isinstance(path[0], int):
                first_half = Values.get_path_expression(path[0])
                second_half = Values.get_path_expression(path[1])
                alternatives.extend(
                    Values.get_logical_expression(btor2.OP_AND, first_half + second_half))
                continue
            assert isinstance(path[1], int)
            input_index, input_set = path[0], path[1]
            alternatives.extend(
                Values.get_input_expression(Values.CFLOBVDD_input[input_index], input_set))
        return Values.get_logical_expression(btor2.OP_OR, alternatives)

    def get_cflobvdd_expression(self):
        # Translates the CFLOBVDD of these values into a BTOR2 expression
        # of sort self.sid_line: each output becomes a Constd line and the
        # outputs are chained with Ite lines whose conditions are derived
        # from the paths leading to the respective exit.
        #
        # Returns the constant itself if there is only one output and
        # None if there are no outputs at all.
        cflobvdd = self.cflobvdd
        exp_line = None
        for exit_i in cflobvdd.outputs:
            output_line = Constd(btor2.Parser.next_nid(), self.sid_line,
                int(cflobvdd.outputs[exit_i]),
                "domain-propagated value", 0)
            if len(cflobvdd.outputs) == 1:
                # dont-care output
                return output_line
            elif exp_line is None:
                # reachable only if input value is in inputs
                # (first output serves as innermost else-branch, no condition needed)
                exp_line = output_line
            else:
                input_line = Values.get_path_expression(cflobvdd.grouping.get_paths(exit_i))
                # an empty list would mean that the exit has no paths leading to it
                assert input_line
                exp_line = Ite(btor2.Parser.next_nid(), self.sid_line,
                    input_line[0],
                    output_line,
                    exp_line,
                    self.sid_line.comment, self.sid_line.line_no)
        return exp_line

    # expressions

    def FALSE():
        # lazily created shared constant for boolean false
        cached = Values.false
        if cached is None:
            cached = Values.false = Values(Bool.boolean, False)
        return cached

    def TRUE():
        # lazily created shared constant for boolean true
        cached = Values.true
        if cached is None:
            cached = Values.true = Values(Bool.boolean, True)
        return cached

    def is_always_false(self):
        # ask the BVDD if enabled, else the CFLOBVDD; with both enabled
        # their answers are cross-checked; returns None if neither is enabled
        assert isinstance(self.sid_line, Bool)
        if Values.BVDD:
            if Values.CFLOBVDD:
                assert self.bvdd.is_always_false() == self.cflobvdd.is_always_false()
            return self.bvdd.is_always_false()
        if Values.CFLOBVDD:
            return self.cflobvdd.is_always_false()
        return None

    def is_always_true(self):
        # ask the BVDD if enabled, else the CFLOBVDD; with both enabled
        # their answers are cross-checked; returns None if neither is enabled
        assert isinstance(self.sid_line, Bool)
        if Values.BVDD:
            if Values.CFLOBVDD:
                assert self.bvdd.is_always_true() == self.cflobvdd.is_always_true()
            return self.bvdd.is_always_true()
        if Values.CFLOBVDD:
            return self.cflobvdd.is_always_true()
        return None

    def get_expression(self):
        # naive transition from domain propagation to bit blasting:
        # the BVDD is translated if enabled, else the CFLOBVDD
        assert isinstance(self.sid_line, Bitvector)
        if not Values.BVDD:
            assert Values.CFLOBVDD
            return self.get_cflobvdd_expression()
        return self.get_bvdd_expression()

    # per-value semantics of value sets

    # unary operators

    def apply_unary(self, sid_line, op):
        # apply op to each value in every enabled diagram and
        # wrap the results as values of sort sid_line
        bvdd = self.bvdd.compute_unary(op) if Values.BVDD else None
        if Values.CFLOBVDD:
            cflobvdd = self.cflobvdd.unary_apply_and_reduce(op, sid_line.size)
        else:
            cflobvdd = None
        return Values(sid_line, None, None, bvdd, cflobvdd)

    def SignExt(self, sid_line):
        # sign extension: read each value as signed in the source sort and
        # reduce it modulo 2 to the power of the size of the target sort
        assert isinstance(self.sid_line, Bitvec)
        to_signed = self.sid_line.get_signed_value
        modulus = 2**sid_line.size
        return self.apply_unary(sid_line, lambda x: to_signed(x) % modulus)

    def ZeroExt(self, sid_line):
        # zero extension leaves unsigned values unchanged, only the sort changes
        assert isinstance(self.sid_line, Bitvec)
        identity = lambda x: x
        return self.apply_unary(sid_line, identity)

    def Extract(self, sid_line, u, l):
        # keep bits u down to l: mask away everything above bit u,
        # then shift out everything below bit l
        assert isinstance(self.sid_line, Bitvec)
        mask = 2**(u + 1) - 1
        return self.apply_unary(sid_line, lambda x: (x & mask) >> l)

    def Not(self):
        # boolean negation
        assert isinstance(self.sid_line, Bool)
        negate = lambda x: not x
        return self.apply_unary(self.sid_line, negate)

    def __invert__(self):
        # bitwise complement, reduced to the size of the sort
        assert isinstance(self.sid_line, Bitvec)
        modulus = 2**self.sid_line.size
        return self.apply_unary(self.sid_line, lambda x: (~x) % modulus)

    def Inc(self):
        # increment with wrap-around
        assert isinstance(self.sid_line, Bitvec)
        modulus = 2**self.sid_line.size
        return self.apply_unary(self.sid_line, lambda x: (x + 1) % modulus)

    def Dec(self):
        # decrement with wrap-around
        assert isinstance(self.sid_line, Bitvec)
        modulus = 2**self.sid_line.size
        return self.apply_unary(self.sid_line, lambda x: (x - 1) % modulus)

    def __neg__(self):
        # two's complement negation
        assert isinstance(self.sid_line, Bitvec)
        modulus = 2**self.sid_line.size
        return self.apply_unary(self.sid_line, lambda x: (-x) % modulus)

    # binary operators

    def apply_binary(self, sid_line, values, op):
        # apply op to each pair of values of self and values in every
        # enabled diagram and wrap the results as values of sort sid_line
        assert isinstance(values, Values), f"{self} {op} {values}"
        bvdd = self.bvdd.compute_binary(op, values.bvdd) if Values.BVDD else None
        if Values.CFLOBVDD:
            cflobvdd = self.cflobvdd.binary_apply_and_reduce(values.cflobvdd, op, sid_line.size)
        else:
            cflobvdd = None
        return Values(sid_line, None, None, bvdd, cflobvdd)

    def Implies(self, values):
        # logical implication; values is not looked at if self is always false
        assert isinstance(self.sid_line, Bool)
        if self.is_always_false():
            return Values.TRUE()
        # lazy evaluation of implied values
        assert isinstance(values, Values) and isinstance(values.sid_line, Bool)
        implication = lambda x, y: (not x) or y
        return self.apply_binary(Bool.boolean, values, implication)

    def __eq__(self, values):
        # per-value equality resulting in boolean values, not in a Python bool
        assert isinstance(self.sid_line, Bitvector)
        assert self.sid_line.match_sorts(values.sid_line)
        equal = lambda x, y: x == y
        return self.apply_binary(Bool.boolean, values, equal)

    def __ne__(self, values):
        # per-value inequality resulting in boolean values, not in a Python bool
        assert isinstance(self.sid_line, Bitvector)
        assert self.sid_line.match_sorts(values.sid_line)
        unequal = lambda x, y: x != y
        return self.apply_binary(Bool.boolean, values, unequal)

    def __gt__(self, values):
        # signed greater-than
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        left_signed = self.sid_line.get_signed_value
        right_signed = values.sid_line.get_signed_value
        return self.apply_binary(Bool.boolean, values,
            lambda x, y: left_signed(x) > right_signed(y))

    def UGT(self, values):
        # unsigned greater-than
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        greater = lambda x, y: x > y
        return self.apply_binary(Bool.boolean, values, greater)

    def __ge__(self, values):
        # signed greater-than-or-equal
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        left_signed = self.sid_line.get_signed_value
        right_signed = values.sid_line.get_signed_value
        return self.apply_binary(Bool.boolean, values,
            lambda x, y: left_signed(x) >= right_signed(y))

    def UGE(self, values):
        # unsigned greater-than-or-equal
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        greater_or_equal = lambda x, y: x >= y
        return self.apply_binary(Bool.boolean, values, greater_or_equal)

    def __lt__(self, values):
        # signed less-than
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        left_signed = self.sid_line.get_signed_value
        right_signed = values.sid_line.get_signed_value
        return self.apply_binary(Bool.boolean, values,
            lambda x, y: left_signed(x) < right_signed(y))

    def ULT(self, values):
        # unsigned less-than
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        less = lambda x, y: x < y
        return self.apply_binary(Bool.boolean, values, less)

    def __le__(self, values):
        # signed less-than-or-equal
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        left_signed = self.sid_line.get_signed_value
        right_signed = values.sid_line.get_signed_value
        return self.apply_binary(Bool.boolean, values,
            lambda x, y: left_signed(x) <= right_signed(y))

    def ULE(self, values):
        # unsigned less-than-or-equal
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        less_or_equal = lambda x, y: x <= y
        return self.apply_binary(Bool.boolean, values, less_or_equal)

    def And(self, values):
        # logical conjunction; values is not looked at if self is always false
        assert isinstance(self.sid_line, Bool)
        if self.is_always_false():
            return Values.FALSE()
        # lazy evaluation of second operand
        assert isinstance(values, Values) and isinstance(values.sid_line, Bool)
        conjunction = lambda x, y: x and y
        return self.apply_binary(Bool.boolean, values, conjunction)

    def Or(self, values):
        # logical disjunction; values is not looked at if self is always true
        assert isinstance(self.sid_line, Bool)
        if self.is_always_true():
            return Values.TRUE()
        # lazy evaluation of second operand
        assert isinstance(values, Values) and isinstance(values.sid_line, Bool)
        disjunction = lambda x, y: x or y
        return self.apply_binary(Bool.boolean, values, disjunction)

    def Xor(self, values):
        # logical exclusive-or: true where the operands differ
        assert isinstance(self.sid_line, Bool)
        assert isinstance(values.sid_line, Bool)
        differ = lambda x, y: x != y
        return self.apply_binary(Bool.boolean, values, differ)

    def __and__(self, values):
        # bitwise and
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        bitwise_and = lambda x, y: x & y
        return self.apply_binary(self.sid_line, values, bitwise_and)

    def __or__(self, values):
        # bitwise or
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        bitwise_or = lambda x, y: x | y
        return self.apply_binary(self.sid_line, values, bitwise_or)

    def __xor__(self, values):
        # bitwise exclusive-or
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        bitwise_xor = lambda x, y: x ^ y
        return self.apply_binary(self.sid_line, values, bitwise_xor)

    def __lshift__(self, values):
        # logical left shift with the result truncated to the size of the sort
        #
        # shifting by the size of the sort or more always yields 0, which is
        # returned directly: the shift amount is an unsigned value that may be
        # as large as 2**size - 1 and computing x << y for such an amount would
        # build an astronomically large intermediate integer (or fail with a
        # MemoryError or OverflowError) just to have it reduced to 0 afterwards
        assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
        size = self.sid_line.size
        modulus = 2**size
        return self.apply_binary(self.sid_line, values,
            lambda x, y: (x << y) % modulus if y < size else 0)

    def LShR(self, values):
        # logical right shift of unsigned values
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        modulus = 2**self.sid_line.size
        return self.apply_binary(self.sid_line, values, lambda x, y: (x >> y) % modulus)

    def __rshift__(self, values):
        # right shift operator computes arithmetic right shift in Python,
        # hence the first operand is shifted as signed value and
        # the result is mapped back to an unsigned value
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        to_signed = self.sid_line.get_signed_value
        modulus = 2**self.sid_line.size
        return self.apply_binary(self.sid_line, values,
            lambda x, y: (to_signed(x) >> y) % modulus)

    def __add__(self, values):
        # addition with wrap-around
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        modulus = 2**self.sid_line.size
        return self.apply_binary(self.sid_line, values, lambda x, y: (x + y) % modulus)

    def __sub__(self, values):
        # subtraction with wrap-around
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        modulus = 2**self.sid_line.size
        return self.apply_binary(self.sid_line, values, lambda x, y: (x - y) % modulus)

    def __mul__(self, values):
        # multiplication with wrap-around
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        modulus = 2**self.sid_line.size
        return self.apply_binary(self.sid_line, values, lambda x, y: (x * y) % modulus)

    def __truediv__(self, values):
        # signed division truncating towards 0 with the result mapped back
        # to an unsigned value of the sort
        #
        # floor division with the // operator cannot be used on the signed
        # operands directly since it truncates towards negative infinity if
        # exactly one operand is negative; true division with the / operator
        # followed by int() cannot be used either since it goes through
        # floating point and silently loses precision for operands wider than
        # 53 bits (for example 64-bit machine words); dividing the absolute
        # values with // and fixing the sign afterwards is exact for all sizes
        #
        # special cases:
        # - division by zero yields all ones (-1)
        # - dividing the smallest signed value by -1 overflows and
        #   yields the smallest signed value again
        assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
        size = self.sid_line.size
        modulus = 2**size
        smallest_signed = -2**(size - 1)

        def signed_division(x, y):
            if y == 0:
                return -1 % modulus
            dividend = self.sid_line.get_signed_value(x)
            divisor = values.sid_line.get_signed_value(y)
            if dividend == smallest_signed and divisor == -1:
                return smallest_signed % modulus
            quotient = abs(dividend) // abs(divisor)
            if (dividend < 0) != (divisor < 0):
                quotient = -quotient
            return quotient % modulus

        return self.apply_binary(self.sid_line, values, signed_division)

    def UDiv(self, values):
        # unsigned division where division by zero yields all ones;
        # using floor division is ok since x >= 0 and y >= 0
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        all_ones = 2**self.sid_line.size - 1
        return self.apply_binary(self.sid_line, values,
            lambda x, y: all_ones if y == 0 else x // y)

    def SRem(self, values):
        # signed remainder of division truncating towards 0, that is, the
        # result takes the sign of the dividend, mapped back to an unsigned
        # value of the sort
        #
        # the % operator cannot be used on the signed operands directly since
        # it computes the modulus matching floor division in Python, which
        # differs from the remainder if exactly one operand is negative;
        # going through true division with the / operator and int() cannot
        # be used either since floating point silently loses precision for
        # operands wider than 53 bits (for example 64-bit machine words);
        # taking the remainder of the absolute values with % and fixing the
        # sign afterwards is exact for all sizes
        #
        # special cases:
        # - the remainder of division by zero is the dividend
        # - the remainder of dividing the smallest signed value by -1 is 0
        #   (covered by the general case since any remainder modulo 1 is 0)
        assert isinstance(self.sid_line, Bitvec) and self.sid_line.match_sorts(values.sid_line)
        modulus = 2**self.sid_line.size

        def signed_remainder(x, y):
            if y == 0:
                return x
            dividend = self.sid_line.get_signed_value(x)
            divisor = values.sid_line.get_signed_value(y)
            remainder = abs(dividend) % abs(divisor)
            if dividend < 0:
                remainder = -remainder
            return remainder % modulus

        return self.apply_binary(self.sid_line, values, signed_remainder)

    def URem(self, values):
        # unsigned remainder where the remainder of division by zero is the dividend;
        # using the % operator is ok since x >= 0 and y >= 0
        assert isinstance(self.sid_line, Bitvec)
        assert self.sid_line.match_sorts(values.sid_line)
        remainder = lambda x, y: x if y == 0 else x % y
        return self.apply_binary(self.sid_line, values, remainder)

    def Concat(self, values, sid_line):
        # concatenation: values of self become the upper bits and
        # values of the second operand the lower bits of the result
        assert isinstance(self.sid_line, Bitvec)
        assert isinstance(values.sid_line, Bitvec)
        lower_size = values.sid_line.size
        return self.apply_binary(sid_line, values, lambda x, y: (x << lower_size) + y)

    # ternary operator

    def ite(self, values2, values3):
        # if-then-else with self as boolean condition: selects per value
        # from values2 where the condition is true and from values3 otherwise;
        # the result has the sort of values2
        assert isinstance(self.sid_line, Bool)
        assert isinstance(values2, Values) and isinstance(values3, Values)
        assert values2.match_sorts(values3)
        result_sort = values2.sid_line
        bvdd = self.bvdd.compute_ite(values2.bvdd, values3.bvdd) if Values.BVDD else None
        if Values.CFLOBVDD:
            cflobvdd = self.cflobvdd.ternary_apply_and_reduce(values2.cflobvdd, values3.cflobvdd,
                lambda x, y, z: y if x else z, result_sort.size)
        else:
            cflobvdd = None
        return Values(result_sort, None, None, bvdd, cflobvdd)

# NOTE(review): not referenced in this part of the file; purpose to be confirmed
LAMBDAS = True

# if true, states are resolved through their init and next lines
# when values are computed (see State.compute_values)
UNROLL = False

# None disables domain propagation; otherwise variables of bitvector sorts
# with at most this many bits are turned into Values (see Variable.compute_values)
PROPAGATE = None
# allow operators with one operand (Ext, Slice, Unary) to compute on Values
PROPAGATE_UNARY = True
# allow operators with two operands (such as Implies and Comparison) to compute on Values
PROPAGATE_BINARY = True
# NOTE(review): presumably the same for if-then-else; not referenced in this part of the file
PROPAGATE_ITE = True

# base class of all lines: combines the BTOR2 line with the Z3 and bitwuzla interfaces
class Line(btor2.Line, z3interface.Z3, bitwuzlainterface.Bitwuzla):
    def __init__(self):
        # only the solver interfaces are initialized here; btor2.Line.__init__
        # is not called, initializing the BTOR2 part is left to subclasses
        z3interface.Z3.__init__(self)
        bitwuzlainterface.Bitwuzla.__init__(self)

# base class of all sort lines
class Sort(Line, btor2.Sort):
    def __init__(self):
        # btor2.Sort.__init__ is not called here
        Line.__init__(self)

# common base class of the Bool and Bitvec sorts
class Bitvector(Sort, btor2.Bitvector):
    def __init__(self):
        # btor2.Bitvector.__init__ is not called here
        Sort.__init__(self)

# boolean sort with Z3 and bitwuzla support
class Bool(Bitvector, btor2.Bool, z3interface.Bool, bitwuzlainterface.Bool):
    def __init__(self, nid, comment, line_no):
        # solver interfaces first (through Bitvector), then the BTOR2 part
        Bitvector.__init__(self)
        btor2.Bool.__init__(self, nid, comment, line_no)

# bitvector sort of a given size with Z3 and bitwuzla support
class Bitvec(Bitvector, btor2.Bitvec, z3interface.Bitvec, bitwuzlainterface.Bitvec):
    def __init__(self, nid, size, comment, line_no):
        # solver interfaces first (through Bitvector), then the BTOR2 part
        Bitvector.__init__(self)
        btor2.Bitvec.__init__(self, nid, size, comment, line_no)

# array sort, given by the sort lines of its size and its elements,
# with Z3 and bitwuzla support
class Array(Sort, btor2.Array, z3interface.Array, bitwuzlainterface.Array):
    def __init__(self, nid, array_size_line, element_size_line, comment, line_no):
        # solver interfaces first (through Sort), then the BTOR2 part
        Sort.__init__(self)
        btor2.Array.__init__(self, nid, array_size_line, element_size_line, comment, line_no)

class Expression(Line, btor2.Expression, z3interface.Expression, bitwuzlainterface.Expression):
    # base class of all expression lines; caches the values computed per step
    def __init__(self):
        Line.__init__(self)
        # step -> (computed values, solver version at the time of computation)
        self.cache_values = {}
        z3interface.Expression.__init__(self)
        bitwuzlainterface.Expression.__init__(self)

    def get_expression(self):
        # an expression line is its own expression
        return self

    def get_values(self, step):
        # returns the values of this expression at the given step, which are
        # computed by compute_values of the subclass and cached; without
        # unrolling and propagation the expression stands for itself
        assert step >= 0
        if not UNROLL and PROPAGATE is None:
            return self
        # versioning needed for support of branching in bitme solver
        if step in self.cache_values:
            is_stale = self.cache_values[step][1] not in Bitme_Solver.versions
        else:
            is_stale = True
        if is_stale:
            self.cache_values[step] = (self.compute_values(step), Bitme_Solver.version)
        return self.cache_values[step][0]

class Constant(Expression, btor2.Constant, z3interface.Constant, bitwuzlainterface.Constant):
    # base class of all constant lines
    def __init__(self):
        Expression.__init__(self)

    def compute_values(self, step):
        # with propagation enabled a constant turns into Values,
        # using the shared TRUE and FALSE values for booleans
        assert step == 0
        if PROPAGATE is None:
            return self
        if isinstance(self.sid_line, Bool):
            if bool(self.value):
                return Values.TRUE()
            return Values.FALSE()
        assert isinstance(self.sid_line, Bitvec)
        return Values(self.sid_line, self.value)

    def get_values(self, step):
        # constants do not change over time: always use step 0
        assert step >= 0
        return super().get_values(0)

# constant line of the BTOR2 zero operator
class Zero(Constant, btor2.Zero):
    def __init__(self, nid, sid_line, symbol, comment, line_no):
        # caches and solver interfaces first (through Constant), then the BTOR2 part
        Constant.__init__(self)
        btor2.Zero.__init__(self, nid, sid_line, symbol, comment, line_no)

# constant line of the BTOR2 one operator
class One(Constant, btor2.One):
    def __init__(self, nid, sid_line, symbol, comment, line_no):
        # caches and solver interfaces first (through Constant), then the BTOR2 part
        Constant.__init__(self)
        btor2.One.__init__(self, nid, sid_line, symbol, comment, line_no)

# constant line of the BTOR2 constd operator
class Constd(Constant, btor2.Constd):
    def __init__(self, nid, sid_line, value, comment, line_no):
        # caches and solver interfaces first (through Constant), then the BTOR2 part
        Constant.__init__(self)
        btor2.Constd.__init__(self, nid, sid_line, value, comment, line_no)

# constant line of the BTOR2 const operator
class Const(Constant, btor2.Const):
    def __init__(self, nid, sid_line, value, comment, line_no):
        # caches and solver interfaces first (through Constant), then the BTOR2 part
        Constant.__init__(self)
        btor2.Const.__init__(self, nid, sid_line, value, comment, line_no)

# constant line of the BTOR2 consth operator
class Consth(Constant, btor2.Consth):
    def __init__(self, nid, sid_line, value, comment, line_no):
        # caches and solver interfaces first (through Constant), then the BTOR2 part
        Constant.__init__(self)
        btor2.Consth.__init__(self, nid, sid_line, value, comment, line_no)

# array whose sort is sid_line, built from the constant line constant_line
class Constant_Array(Expression, btor2.Constant_Array, z3interface.Constant_Array, bitwuzlainterface.Constant_Array):
    def __init__(self, sid_line, constant_line):
        # caches and solver interfaces first (through Expression), then the BTOR2 part
        Expression.__init__(self)
        btor2.Constant_Array.__init__(self, sid_line, constant_line)

    def get_values(self, step):
        # constant arrays are never propagated or cached:
        # the line stands for itself at every step
        assert step >= 0
        return self

class Variable(Expression, btor2.Variable, z3interface.Variable):
    # common base class of inputs and states
    def __init__(self):
        Expression.__init__(self)

    def compute_values(self, step):
        # a variable of bitvector sort turns into Values if propagation is
        # enabled and its size does not exceed the propagation limit,
        # otherwise it stands for itself
        assert step == 0
        is_propagated = (isinstance(self.sid_line, Bitvector) and
            PROPAGATE is not None and
            self.sid_line.size <= PROPAGATE)
        if not is_propagated:
            return self
        return Values(self.sid_line, None, self)

# input line with Z3 and bitwuzla support
class Input(Variable, btor2.Input, z3interface.Input, bitwuzlainterface.Input):
    def __init__(self, nid, sid_line, symbol, comment, line_no, index = None):
        # caches and solver interfaces first (through Variable), then the BTOR2 part
        Variable.__init__(self)
        btor2.Input.__init__(self, nid, sid_line, symbol, comment, line_no, index)

    def get_step_name(self, step):
        # the name of an input does not depend on the step
        return self.name

    def compute_values(self, step):
        # the values of an input are computed as for step 0 regardless of the step
        return super().compute_values(0)

class State(Variable, btor2.State, z3interface.State, bitwuzlainterface.State):
    # the most recently constructed state with symbol "core-0-pc", if any
    pc = None

    def __init__(self, nid, sid_line, symbol, comment, line_no, index = None):
        Variable.__init__(self)
        btor2.State.__init__(self, nid, sid_line, symbol, comment, line_no, index)
        z3interface.State.__init__(self)
        bitwuzlainterface.State.__init__(self)
        if symbol == "core-0-pc":
            State.pc = self

    def get_step_name(self, step):
        # the name of a state is qualified by the step
        return f"{self.name}-{step}"

    def compute_values(self, step):
        # step 0 is determined by the init line and later steps by the next line,
        # both of which are only followed when unrolling
        assert step >= 0
        if step == 0:
            if self.init_line is None:
                # uninitialized state
                return super().compute_values(0)
            return self.init_line.wait_step(0) if UNROLL else self
        stays_unchanged = self.next_line is None or self.next_line.exp_line is self
        if stays_unchanged:
            # untransitioned state or transitioned to itself
            return self.get_values(0)
        if UNROLL:
            return self.next_line.wait_step(step - 1)
        return self

# common base class of the Ext and Slice operators
class Indexed(Expression, btor2.Indexed):
    def __init__(self):
        # btor2.Indexed.__init__ is not called here
        Expression.__init__(self)

class Ext(Indexed, btor2.Ext, z3interface.Ext, bitwuzlainterface.Ext):
    # sign and zero extension of the operand by w bits
    def __init__(self, nid, op, sid_line, arg1_line, w, comment, line_no):
        Indexed.__init__(self)
        btor2.Ext.__init__(self, nid, op, sid_line, arg1_line, w, comment, line_no)

    def compute_values(self, step):
        # extend the Values of the operand if propagation of unary operators
        # is enabled, otherwise copy this line with the operand as expression
        operand = self.arg1_line.get_values(step)
        if not (PROPAGATE_UNARY and isinstance(operand, Values)):
            return self.copy(operand.get_expression())
        if self.op == btor2.OP_SEXT:
            return operand.SignExt(self.sid_line)
        assert self.op == btor2.OP_UEXT
        return operand.ZeroExt(self.sid_line)

class Slice(Indexed, btor2.Slice, z3interface.Slice, bitwuzlainterface.Slice):
    # extraction of the bits u down to l of the operand
    def __init__(self, nid, sid_line, arg1_line, u, l, comment, line_no):
        Indexed.__init__(self)
        btor2.Slice.__init__(self, nid, sid_line, arg1_line, u, l, comment, line_no)

    def compute_values(self, step):
        # slice the Values of the operand if propagation of unary operators
        # is enabled, otherwise copy this line with the operand as expression
        operand = self.arg1_line.get_values(step)
        if not (PROPAGATE_UNARY and isinstance(operand, Values)):
            return self.copy(operand.get_expression())
        return operand.Extract(self.sid_line, self.u, self.l)

class Unary(Expression, btor2.Unary, z3interface.Unary, bitwuzlainterface.Unary):
    # the not, inc, dec, and neg operators
    def __init__(self, nid, op, sid_line, arg1_line, comment, line_no):
        Expression.__init__(self)
        btor2.Unary.__init__(self, nid, op, sid_line, arg1_line, comment, line_no)

    def compute_values(self, step):
        # apply the operator to the Values of the operand if propagation
        # of unary operators is enabled, otherwise copy this line with
        # the operand as expression
        operand = self.arg1_line.get_values(step)
        if not (PROPAGATE_UNARY and isinstance(operand, Values)):
            return self.copy(operand.get_expression())
        if self.op == btor2.OP_NOT:
            # logical negation for booleans, bitwise complement otherwise
            if isinstance(self.sid_line, Bool):
                return operand.Not()
            return ~operand
        if self.op == btor2.OP_INC:
            return operand.Inc()
        if self.op == btor2.OP_DEC:
            return operand.Dec()
        assert self.op == btor2.OP_NEG
        return -operand

# common base class of the operators with two operands
class Binary(Expression, btor2.Binary):
    def __init__(self):
        # btor2.Binary.__init__ is not called here
        Expression.__init__(self)

class Implies(Binary, btor2.Implies, z3interface.Implies, bitwuzlainterface.Implies):
    # implication with lazy evaluation of the implied operand

    def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
        Binary.__init__(self)
        btor2.Implies.__init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no)

    def compute_values(self, step):
        premise = self.arg1_line.get_values(step)
        if PROPAGATE_BINARY and isinstance(premise, Values):
            if premise.is_always_false():
                # the implied operand is never looked at
                return premise.Implies(None)
            # the premise may hold: now the implied operand is needed
            implied = self.arg2_line.get_values(step)
            if isinstance(implied, Values):
                return premise.Implies(implied)
        else:
            implied = self.arg2_line.get_values(step)
        # at least one operand is not propagated: fall back to expressions
        return self.copy(premise.get_expression(), implied.get_expression())

class Comparison(Binary, btor2.Comparison, z3interface.Comparison, bitwuzlainterface.Comparison):
    # equality, disequality, and signed as well as unsigned ordering

    def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
        Binary.__init__(self)
        btor2.Comparison.__init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no)

    def compute_values(self, step):
        lhs = self.arg1_line.get_values(step)
        rhs = self.arg2_line.get_values(step)
        if PROPAGATE_BINARY and isinstance(lhs, Values) and isinstance(rhs, Values):
            # signed comparisons map to Python operators,
            # unsigned comparisons to named methods of Values
            op = self.op
            if op == btor2.OP_EQ:
                return lhs == rhs
            if op == btor2.OP_NEQ:
                return lhs != rhs
            if op == btor2.OP_SGT:
                return lhs > rhs
            if op == btor2.OP_UGT:
                return lhs.UGT(rhs)
            if op == btor2.OP_SGTE:
                return lhs >= rhs
            if op == btor2.OP_UGTE:
                return lhs.UGE(rhs)
            if op == btor2.OP_SLT:
                return lhs < rhs
            if op == btor2.OP_ULT:
                return lhs.ULT(rhs)
            if op == btor2.OP_SLTE:
                return lhs <= rhs
            assert op == btor2.OP_ULTE
            return lhs.ULE(rhs)
        # no propagation: copy this line over the operand expressions
        return self.copy(lhs.get_expression(), rhs.get_expression())

    def is_always_false(self):
        # only needed for termination check
        return False

    def get_step(self, step):
        # only needed for termination check
        assert step >= 0
        values = self.get_values(step)
        return values

    def wait_step(self, step):
        # no waiting
        return self.get_step(step)

class Logical(Binary, btor2.Logical, z3interface.Logical, bitwuzlainterface.Logical):
    # and, or, xor: lazily evaluated on Boolean sort, bitwise otherwise

    def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
        Binary.__init__(self)
        btor2.Logical.__init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no)

    def compute_values(self, step):
        if PROPAGATE_BINARY and isinstance(self.sid_line, Bool):
            left = self.arg1_line.get_values(step)
            if isinstance(left, Values):
                op = self.op
                if op == btor2.OP_AND:
                    if left.is_always_false():
                        # second operand is never looked at
                        return left.And(None)
                    combine = left.And
                elif op == btor2.OP_OR:
                    if left.is_always_true():
                        # second operand is never looked at
                        return left.Or(None)
                    combine = left.Or
                else:
                    assert op == btor2.OP_XOR
                    combine = left.Xor
                # lazy evaluation of second operand
                right = self.arg2_line.get_values(step)
                if isinstance(right, Values):
                    return combine(right)
            # not both operands are propagated
            right = self.arg2_line.get_values(step)
        else:
            left = self.arg1_line.get_values(step)
            right = self.arg2_line.get_values(step)
            if PROPAGATE_BINARY and isinstance(left, Values) and isinstance(right, Values):
                # sort is not Boolean: bitwise operators on values
                if self.op == btor2.OP_AND:
                    return left & right
                if self.op == btor2.OP_OR:
                    return left | right
                assert self.op == btor2.OP_XOR
                return left ^ right
        # no propagation: copy this line over the operand expressions
        return self.copy(left.get_expression(), right.get_expression())

class Computation(Binary, btor2.Computation, z3interface.Computation, bitwuzlainterface.Computation):
    # shifts and arithmetic on two operands

    def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
        Binary.__init__(self)
        btor2.Computation.__init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no)

    def compute_values(self, step):
        lhs = self.arg1_line.get_values(step)
        rhs = self.arg2_line.get_values(step)
        if PROPAGATE_BINARY and isinstance(lhs, Values) and isinstance(rhs, Values):
            # propagation: apply the operator directly to the values
            op = self.op
            if op == btor2.OP_SLL:
                return lhs << rhs
            if op == btor2.OP_SRL:
                return lhs.LShR(rhs)
            if op == btor2.OP_SRA:
                return lhs >> rhs
            if op == btor2.OP_ADD:
                return lhs + rhs
            if op == btor2.OP_SUB:
                return lhs - rhs
            if op == btor2.OP_MUL:
                return lhs * rhs
            if op == btor2.OP_SDIV:
                return lhs / rhs
            if op == btor2.OP_UDIV:
                return lhs.UDiv(rhs)
            if op == btor2.OP_SREM:
                return lhs.SRem(rhs)
            assert op == btor2.OP_UREM
            return lhs.URem(rhs)
        # no propagation: copy this line over the operand expressions
        return self.copy(lhs.get_expression(), rhs.get_expression())

class Concat(Binary, btor2.Concat, z3interface.Concat, bitwuzlainterface.Concat):
    # concatenation of two operands

    def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
        Binary.__init__(self)
        btor2.Concat.__init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no)

    def compute_values(self, step):
        first = self.arg1_line.get_values(step)
        second = self.arg2_line.get_values(step)
        if PROPAGATE_BINARY and isinstance(first, Values) and isinstance(second, Values):
            # propagation: concatenate the values into the sort of this line
            return first.Concat(second, self.sid_line)
        # no propagation: copy this line over the operand expressions
        return self.copy(first.get_expression(), second.get_expression())

class Read(Binary, btor2.Read, z3interface.Read, bitwuzlainterface.Read):
    # read with two operands; never propagated, always an expression

    def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
        Binary.__init__(self)
        btor2.Read.__init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no)

    def compute_values(self, step):
        # turn both operands, first to second, into expressions
        operands = [line.get_values(step).get_expression()
            for line in (self.arg1_line, self.arg2_line)]
        return self.copy(*operands)

class Ternary(Expression, btor2.Ternary):
    # common base of the three-operand expression lines in this file
    # (Ite and Write)
    def __init__(self):
        # only the Expression part is initialized here; each subclass calls
        # the initializer of its own btor2 counterpart itself
        Expression.__init__(self)

class Ite(Ternary, btor2.Ite, z3interface.Ite, bitwuzlainterface.Ite):
    # if-then-else with lazy evaluation of the true and false case

    def __init__(self, nid, sid_line, arg1_line, arg2_line, arg3_line, comment, line_no):
        Ternary.__init__(self)
        btor2.Ite.__init__(self, nid, sid_line, arg1_line, arg2_line, arg3_line, comment, line_no)

    def compute_values(self, step):
        condition = self.arg1_line.get_values(step)
        if PROPAGATE_ITE and isinstance(condition, Values):
            if condition.is_always_true():
                # true case holds unconditionally, false case is not evaluated
                chosen = self.arg2_line.get_values(step)
                return chosen if isinstance(chosen, Values) else chosen.get_expression()
            if condition.is_always_false():
                # false case holds unconditionally, true case is not evaluated
                chosen = self.arg3_line.get_values(step)
                return chosen if isinstance(chosen, Values) else chosen.get_expression()
            # condition is undecided: both cases are needed
            true_case = self.arg2_line.get_values(step)
            false_case = self.arg3_line.get_values(step)
            if isinstance(true_case, Values) and isinstance(false_case, Values):
                return condition.ite(true_case, false_case)
        else:
            true_case = self.arg2_line.get_values(step)
            false_case = self.arg3_line.get_values(step)
        # no propagation: copy this line over the operand expressions
        return self.copy(condition.get_expression(),
            true_case.get_expression(),
            false_case.get_expression())

    def get_step(self, step):
        # only needed for branching
        assert step >= 0
        values = self.get_values(step)
        return values

    def wait_step(self, step):
        # no waiting
        return self.get_step(step)

class Write(Ternary, btor2.Write, z3interface.Write, bitwuzlainterface.Write):
    # write with three operands; never propagated, always an expression

    def __init__(self, nid, sid_line, arg1_line, arg2_line, arg3_line, comment, line_no):
        Ternary.__init__(self)
        btor2.Write.__init__(self, nid, sid_line, arg1_line, arg2_line, arg3_line, comment, line_no)

    def compute_values(self, step):
        # turn all operands, first to third, into expressions
        operands = [line.get_values(step).get_expression()
            for line in (self.arg1_line, self.arg2_line, self.arg3_line)]
        return self.copy(*operands)

class Transitional(Line, btor2.Transitional):
    # common base of Init and Next; the bitme solver tells these lines
    # apart from property lines with isinstance(..., Transitional)
    def __init__(self):
        # only the Line part is initialized here; each subclass calls
        # the initializer of its own btor2 counterpart itself
        Line.__init__(self)

class Init(Transitional, btor2.Init, z3interface.Init, bitwuzlainterface.Init):
    # initialization line; its value only exists at step 0
    def __init__(self, nid, sid_line, state_line, exp_line, symbol, comment, line_no, array_line = None, index = None):
        Transitional.__init__(self)
        btor2.Init.__init__(self, nid, sid_line, state_line, exp_line, symbol, comment, line_no, array_line, index)

    def get_step(self, step):
        # the value is that of the init expression at step 0;
        # any other step is a usage error
        assert step == 0, f"get init with {step} != 0"
        return self.exp_line.get_values(0)

    def wait_step(self, step):
        # no waiting
        return self.get_step(step)

class Next(Transitional, btor2.Next, z3interface.Next, bitwuzlainterface.Next, Futures):
    # transition line; also provides comparison lines that check
    # whether the state differs from or equals its next-state expression

    def __init__(self, nid, sid_line, state_line, exp_line, symbol, comment, line_no, array_line = None, index = None):
        Transitional.__init__(self)
        btor2.Next.__init__(self, nid, sid_line, state_line, exp_line, symbol, comment, line_no, array_line, index)
        z3interface.Next.__init__(self)
        bitwuzlainterface.Next.__init__(self)
        Futures.__init__(self)
        # both comparison lines are created on first use
        self.state_is_not_changing_line = None
        self.is_state_changing_line = None

    def get_step(self, step):
        assert step >= 0
        values = self.exp_line.get_values(step)
        return values

    def is_state_changing(self):
        # state != next-state expression, created once and then reused
        line = self.is_state_changing_line
        if line is None:
            line = Comparison(btor2.Parser.next_nid(),
                btor2.OP_NEQ, Bool.boolean,
                self.state_line, self.exp_line,
                f"state change check for {self.symbol}", self.line_no)
            self.is_state_changing_line = line
        return line

    def state_is_not_changing(self):
        # state == next-state expression, created once and then reused
        line = self.state_is_not_changing_line
        if line is None:
            line = Comparison(btor2.Parser.next_nid(),
                btor2.OP_EQ, Bool.boolean,
                self.state_line, self.exp_line,
                f"state is not changing for {self.symbol}", self.line_no)
            self.state_is_not_changing_line = line
        return line

class Property(Line, btor2.Property, z3interface.Property, bitwuzlainterface.Property, Futures):
    # common base of Constraint and Bad; Futures adds fork_step and
    # wait_step on top of the get_step defined here
    def __init__(self):
        # each subclass calls the initializer of its own btor2 counterpart itself
        Line.__init__(self)
        Futures.__init__(self)

    def get_step(self, step):
        # the value of a property is the value of the line it refers to
        assert step >= 0
        return self.property_line.get_values(step)

class Constraint(Property, btor2.Constraint):
    # constraint line; the model checker asserts every constraint at each step
    # NOTE(review): Constraint.constraints is read by the model checker but
    # not filled in here - presumably btor2.Constraint registers instances
    def __init__(self, nid, property_line, symbol, comment, line_no):
        Property.__init__(self)
        btor2.Constraint.__init__(self, nid, property_line, symbol, comment, line_no)

class Bad(Property, btor2.Bad):
    # bad property line; the model checker checks each bad property
    # for satisfiability from step kmin on
    # NOTE(review): Bad.bads is read by the model checker but not filled
    # in here - presumably btor2.Bad registers instances
    def __init__(self, nid, property_line, symbol, comment, line_no):
        Property.__init__(self)
        btor2.Bad.__init__(self, nid, property_line, symbol, comment, line_no)

# parser interface

class ValuesParser(btor2.Parser):
    # parser that instantiates the classes of this file
    # in place of their btor2 counterparts

    def get_class(self, clss_or_keyword):
        # map a btor2 class or a keyword to the class of the same name
        # in this file; the result is None if nothing matches
        key = clss_or_keyword

        # sorts are looked up by class only
        if key is btor2.Bool:
            return Bool
        if key is btor2.Bitvec:
            return Bitvec
        if key is btor2.Array:
            return Array

        # constants
        if key is btor2.Constant_Array:
            return Constant_Array
        if key is btor2.Zero or key == btor2.Zero.keyword:
            return Zero
        if key is btor2.One or key == btor2.One.keyword:
            return One
        if key is btor2.Constd or key == btor2.Constd.keyword:
            return Constd
        if key is btor2.Const or key == btor2.Const.keyword:
            return Const
        if key is btor2.Consth or key == btor2.Consth.keyword:
            return Consth

        # variables
        if key is btor2.Input or key == btor2.Input.keyword:
            return Input
        if key is btor2.State or key == btor2.State.keyword:
            return State

        # indexed and unary expressions
        if key is btor2.Ext or key in btor2.Ext.keywords:
            return Ext
        if key is btor2.Slice or key == btor2.Slice.keyword:
            return Slice
        if key is btor2.Unary or key in btor2.Unary.keywords:
            return Unary

        # binary expressions
        if key is btor2.Implies or key == btor2.Implies.keyword:
            return Implies
        if key is btor2.Comparison or key in btor2.Comparison.keywords:
            return Comparison
        if key is btor2.Logical or key in btor2.Logical.keywords:
            return Logical
        if key is btor2.Computation or key in btor2.Computation.keywords:
            return Computation
        if key is btor2.Concat or key == btor2.Concat.keyword:
            return Concat
        if key is btor2.Read or key == btor2.Read.keyword:
            return Read

        # ternary expressions
        if key is btor2.Ite or key == btor2.Ite.keyword:
            return Ite
        if key is btor2.Write or key == btor2.Write.keyword:
            return Write

        # transitions and properties
        if key is btor2.Init or key == btor2.Init.keyword:
            return Init
        if key is btor2.Next or key == btor2.Next.keyword:
            return Next
        if key is btor2.Constraint or key == btor2.Constraint.keyword:
            return Constraint
        if key is btor2.Bad or key == btor2.Bad.keyword:
            return Bad

        return None

# console output

def get_step(step, level):
    # prefix for console messages: empty if step or level is unknown,
    # "step: " on level 0, and "step-level: " on any other level
    if step is None or level is None:
        return ""
    suffix = "" if level == 0 else f"-{level}"
    return f"{step}{suffix}: "

# length of the last message printed without a trailing newline
last_message_length = 0

def print_message(message, step = None, level = None):
    # print the message behind its step prefix; a preceding message
    # without a trailing newline is blanked out and overwritten
    global last_message_length
    if last_message_length > 0:
        blanks = " " * last_message_length
        print("\r%s" % blanks, end='\r')
    line = f"{get_step(step, level)}{message}"
    print(line, end='', flush=True)
    if line[-1:] != '\n':
        # remember how much to blank out next time
        last_message_length = len(line)
    else:
        last_message_length = 0

def print_separator(separator, step = None, level = None):
    # repeat the separator 80 times less the length of the step prefix
    # and print the result as a line of its own
    repetitions = 80 - len(get_step(step, level))
    line = separator * repetitions
    print_message(f"{line}\n", step, level)

def print_message_with_propagation_profile(message, step = None, level = None):
    # put the propagation counters in front of the message
    # whenever unrolling or propagation is enabled
    if UNROLL or PROPAGATE is not None:
        profile = ", ".join((
            f"{Values.total_number_of_constants} constants",
            f"{Values.total_number_of_values} values",
            f"{Values.total_number_of_distinct_inputs} distinct inputs",
            f"{Values.max_number_of_connections} connections",
            f"{Expression.total_number_of_generated_expressions} expressions"))
        message = f"({profile}) {message}"
    print_message(message, step, level)

# bitme solver

class Bitme_Solver:
    # solver front end that first tries to decide queries on propagated
    # Values alone; as soon as the condition of an assertion is not a
    # Values object it replays everything asserted so far onto the Z3
    # and/or bitwuzla solvers it was given and delegates to them from then
    # on (fallback is never reset to False in this class)

    # class-level versioning: every non-fallback push makes a fresh
    # version number current and pop removes it again
    # NOTE(review): the readers of version and versions are not part of
    # this class - presumably they tag cached values; confirm
    versions = {0:None}
    version = 0
    bump = 1 # next version number to hand out

    def __init__(self, z3_solver, bitwuzla_solver):
        # either solver may be falsy, it is then skipped in fallback mode
        self.z3_solver = z3_solver
        self.bitwuzla_solver = bitwuzla_solver
        self.fallback = False
        # one (constraint, assertions) pair per open push
        self.stack = []
        # conjunction of all conditions proven at the current level
        self.constraint = Values.TRUE()
        # step -> {assertion: polarity}, True if asserted, False if negated;
        # unproven holds what prove has not processed yet
        self.proven = {}
        self.unproven = {}

    def push(self):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.push()
            if self.bitwuzla_solver:
                self.bitwuzla_solver.push()
        else:
            # push before proving to enable fallback to other solvers
            self.stack.append((self.constraint, self.proven | self.unproven))
            self.prove() # may trigger fallback to other solvers
            self.proven = {}
            assert not self.unproven
            if not self.fallback:
                # proving may have strengthened constraint
                _, proven = self.stack.pop()
                self.stack.append((self.constraint, proven))
                # open a new version for the new level
                Bitme_Solver.version = Bitme_Solver.bump
                Bitme_Solver.bump += 1
                Bitme_Solver.versions[Bitme_Solver.version] = None

    def pop(self):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.pop()
            if self.bitwuzla_solver:
                self.bitwuzla_solver.pop()
        else:
            assert self.stack
            # restore constraint and assertions saved by the matching push
            self.constraint, self.proven = self.stack.pop()
            self.unproven = {}
            # drop the current version; the most recently inserted of the
            # remaining versions becomes current (dicts keep insertion order)
            del Bitme_Solver.versions[Bitme_Solver.version]
            Bitme_Solver.version = list(Bitme_Solver.versions)[-1]

    def assert_this(self, assertions, step):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.assert_this(assertions, step)
            if self.bitwuzla_solver:
                self.bitwuzla_solver.assert_this(assertions, step)
        else:
            # only record the assertions, prove processes them later
            for assertion in assertions:
                if step not in self.unproven:
                    self.unproven[step] = {assertion:True}
                else:
                    assert assertion not in self.unproven[step]
                    self.unproven[step] |= {assertion:True}

    def assert_not_this(self, assertions, step):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.assert_not_this(assertions, step)
            if self.bitwuzla_solver:
                self.bitwuzla_solver.assert_not_this(assertions, step)
        else:
            # only record the negated assertions, prove processes them later
            for assertion in assertions:
                if step not in self.unproven:
                    self.unproven[step] = {assertion:False}
                else:
                    assert assertion not in self.unproven[step]
                    self.unproven[step] |= {assertion:False}

    def simplify(self):
        # nothing to simplify unless the other solvers are in use
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.simplify()
            if self.bitwuzla_solver:
                self.bitwuzla_solver.simplify()

    def solve(self):
        # switch to fallback mode: from here on all methods delegate, so
        # assert_this, assert_not_this, and push below reach the other solvers
        self.fallback = True
        self.constraint = Values.TRUE()
        self.proven |= self.unproven
        self.unproven = {}
        # the current level goes on top of the stack for the replay
        self.stack.append((None, self.proven))
        # replay all recorded assertions, oldest level first
        for _, assertions in self.stack:
            for step in assertions:
                for assertion in assertions[step]:
                    if assertions[step][assertion]:
                        self.assert_this([assertion], step)
                    else:
                        self.assert_not_this([assertion], step)
            if assertions is not self.proven:
                # push with other solvers except for top of stack
                self.push()
        self.proven = {}
        self.stack = []
        return self.prove()

    def prove(self):
        # returns True if the assertions are satisfiable, False otherwise
        if self.fallback:
            # with both solvers their results must agree; without any
            # solver the result is False
            z3_SAT = False
            if self.z3_solver:
                result = self.z3_solver.prove()
                z3_SAT = self.z3_solver.is_SAT(result)
            if self.bitwuzla_solver:
                result = self.bitwuzla_solver.prove()
                bitwuzla_SAT = self.bitwuzla_solver.is_SAT(result)
                assert not self.z3_solver or z3_SAT == bitwuzla_SAT
                return bitwuzla_SAT
            return z3_SAT
        else:
            for step in self.unproven:
                # transitional lines first: their result is not used here,
                # they are only made to compute their step
                for assertion in self.unproven[step]:
                    if isinstance(assertion, Transitional):
                        assertion.wait_step(step)
                # all other lines contribute a Boolean condition
                for assertion in self.unproven[step]:
                    if not isinstance(assertion, Transitional):
                        condition = assertion.wait_step(step)
                        assert isinstance(condition.sid_line, Bool)
                        if isinstance(condition, Values):
                            # conjoin the condition, negated if recorded so
                            if self.unproven[step][assertion] is True:
                                self.constraint = condition.And(self.constraint)
                            else:
                                assert self.unproven[step][assertion] is False
                                self.constraint = condition.Not().And(self.constraint)
                        else:
                            # condition is not propagated: give up for good
                            # and let the other solvers decide
                            return self.solve()
            self.proven |= self.unproven
            self.unproven = {}
            return not self.constraint.is_always_false()

    def is_SAT(self, result):
        # result is the Boolean returned by prove
        return result

    def is_UNSAT(self, result):
        # result is the Boolean returned by prove
        return not result

    def assert_is_state_changing(self, next_line, step):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.assert_is_state_changing(next_line, step)
            if self.bitwuzla_solver:
                self.bitwuzla_solver.assert_is_state_changing(next_line, step)
        else:
            # assert the comparison line provided by the next line
            self.assert_this([next_line.is_state_changing()], step)

    def assert_state_is_not_changing(self, next_line, step):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.assert_state_is_not_changing(next_line, step)
            if self.bitwuzla_solver:
                self.bitwuzla_solver.assert_state_is_not_changing(next_line, step)
        else:
            # assert the comparison line provided by the next line
            self.assert_this([next_line.state_is_not_changing()], step)

    def print_pc(self, pc, step, level):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.print_pc(pc, step, level)
            if self.bitwuzla_solver:
                self.bitwuzla_solver.print_pc(pc, step, level)
        else:
            # process pending assertions before reading the pc value
            # (this may trigger fallback; printing still happens below)
            self.prove()
            pc_value = pc.get_values(step)
            print_message(f"{pc}\n", step, level)
            print_message("%s = %s\n" % (pc.get_step_name(step), pc_value), step, level)

    def print_inputs(self, inputs, step, level):
        if self.fallback:
            if self.z3_solver:
                self.z3_solver.print_inputs(inputs, step, level)
            if self.bitwuzla_solver:
                self.bitwuzla_solver.print_inputs(inputs, step, level)
        else:
            # inputs, step, and level are not used here: the current
            # constraint is printed as a whole, and nothing is printed
            # if neither BVDDs nor CFLOBVDDs are enabled
            if Values.BVDD:
                print(self.constraint.bvdd.get_printed_BVDD(True))
            if Values.CFLOBVDD:
                print(self.constraint.cflobvdd.get_printed_CFLOBVDD(True))

    def eval_inputs(self, inputs, step):
        if self.fallback:
            # Z3 takes precedence over bitwuzla if both are available
            # NOTE(review): falls through and returns None without any solver
            if self.z3_solver:
                return self.z3_solver.eval_inputs(inputs, step)
            if self.bitwuzla_solver:
                return self.bitwuzla_solver.eval_inputs(inputs, step)
        elif Values.BVDD:
            # sample one assignment of input values from the true constraint
            sample = self.constraint.get_true_constraint().bvdd.sample_input_values()

            # maps the symbol of each input variable to its sampled value
            input_values = dict()
            for input_variable in inputs.values():
                if input_variable in sample:
                    input_values[input_variable.symbol] = sample[input_variable]
                    del sample[input_variable]
                else:
                    # The variable doesn't appear in our sample of the BVDD - this means any value will do
                    input_values[input_variable.symbol] = 0x42 # ord("A"), but also a nice magic value

            assert len(sample) == 0, "sanity check: all branches of the BVDD must be on input values"

            return input_values
        else:
            # TODO
            # NOTE(review): not implemented without BVDDs; None is returned
            # here whereas the analyzor mode of branching_bmc iterates over
            # the items of the result
            pass

# bitr concurrent bitme

def fork(kmin, kmax):
    # submit the computation of all steps from 0 to kmax up front:
    # per step the constraints, from kmin on the bad properties,
    # and then the next-state lines
    for step in range(kmax + 1):
        # assert all constraints
        for constraint in Constraint.constraints.values():
            constraint.fork_step(step)

        # check all bad properties from kmin on
        if step >= kmin:
            for bad in Bad.bads.values():
                bad.fork_step(step)

        # compute next step
        for next_line in Next.nexts.values():
            next_line.fork_step(step)

# bitme bounded model checker

def branching_bmc(solver, kmin, kmax, args, step, level):
    # bounded model checking from the given step on; level is the depth of
    # branching and only shows up in the prefix of console messages
    #
    # per step: assert the constraints, check the bad properties (from kmin
    # on), optionally check for termination, transition to the next step,
    # and optionally branch on Ite.branching_conditions by recursion
    while step <= kmax or args.analyzor:
        # check model up to kmax steps
        # in analyzor mode we keep going until we find a bad input

        if args.print_pc and State.pc:
            # print current program counter value of single-core rotor model
            solver.print_pc(State.pc, step, level)

        # assert all constraints
        for constraint in Constraint.constraints.values():
            print_message_with_propagation_profile(constraint.symbol, step, level)
            solver.assert_this([constraint], step)
            result = solver.prove()
            if solver.is_UNSAT(result):
                # constraint cannot be satisfied: report it and give up
                print_separator('v', step, level)
                print_message(f"{constraint}\n", step, level)
                if UNROLL or PROPAGATE is not None:
                    print_message_with_propagation_profile("propagation profile\n", step, level)
                print_separator('^', step, level)
                return

        if step >= kmin:
            # check bad properties from kmin on
            for bad in Bad.bads.values():
                print_message_with_propagation_profile(bad.symbol, step, level)
                # each bad property is checked on its own solver level
                solver.push()
                solver.assert_this([bad], step)
                result = solver.prove()
                if solver.is_SAT(result):
                    # bad property is reachable: report it with its inputs
                    print_separator('v', step, level)
                    print_message(f"{bad}\n", step, level)
                    solver.print_inputs(Variable.inputs, step, level)
                    if UNROLL or PROPAGATE is not None:
                        print_message_with_propagation_profile("propagation profile\n", step, level)
                    print_separator('^', step, level)

                    if args.analyzor:
                        print("Found bad input; exiting from analyzor mode...")
                        print(f"analyzor#step={step}")
                        print(f"analyzor#bad={bad.symbol}")

                        # NOTE(review): Bitme_Solver.eval_inputs may return
                        # None, which would fail on .items() below
                        input_vals = solver.eval_inputs(Variable.inputs, step)
                        for (name, val) in input_vals.items():
                            print(f"analyzor#input:{name}={val}")

                        # NOTE(review): returns without the pop matching the
                        # push above; a caller that branched pops afterwards
                        return

                solver.pop()

        if not args.unconstraining_bad:
            # assert all bad properties as negated constraints
            solver.assert_not_this(Bad.bads.values(), step)

        if args.check_termination and step >= kmin:
            state_change = False
            for next_line in Next.nexts.values():
                # check if state changes
                solver.push()
                solver.assert_is_state_changing(next_line, step)
                result = solver.prove()
                solver.pop()
                if solver.is_SAT(result):
                    state_change = True
                    print_message(f"state change: {next_line}\n", step, level)
                    # compute next step
                    solver.assert_this([next_line], step)
                else:
                    solver.assert_state_is_not_changing(next_line, step)
                # after the last next line: terminate if no state can change
                if not state_change and next_line is list(Next.nexts.values())[-1]:
                    print_message_with_propagation_profile("no states changed: terminating\n", step, level)
                    return
        else:
            # compute next step
            solver.assert_this(Next.nexts.values(), step)

        if args.print_transition:
            print_message_with_propagation_profile("transitioning\n", step, level)
        else:
            print_message("transitioning", step, level)
        solver.simplify()

        if args.branching and Ite.branching_conditions and Ite.non_branching_conditions:
            print_message_with_propagation_profile("checking branching", step, level)

            # can the branching conditions hold?
            solver.push()
            solver.assert_this([Ite.branching_conditions], step)
            branching_result = solver.is_SAT(solver.prove())
            solver.pop()

            # can the non-branching conditions fail to hold?
            solver.push()
            solver.assert_not_this([Ite.non_branching_conditions], step)
            non_branching_result = solver.is_SAT(solver.prove())
            solver.pop()

            if branching_result != non_branching_result:
                # exactly one case is possible: commit to it and keep going
                if branching_result:
                    solver.assert_this([Ite.branching_conditions], step)
                elif non_branching_result:
                    solver.assert_not_this([Ite.non_branching_conditions], step)

            if branching_result and non_branching_result:
                # both cases are possible: explore each of them recursively
                # from the next step on, one branching level deeper
                print_separator('v', step, level)
                print_message("branching:\n", step, level)

                solver.push()
                solver.assert_this([Ite.branching_conditions], step)
                branching_bmc(solver, kmin, kmax, args, step + 1, level + 1)
                solver.pop()

                print_separator('-', step, level)
                print_message("not branching:\n", step, level)

                solver.push()
                solver.assert_not_this([Ite.non_branching_conditions], step)
                branching_bmc(solver, kmin, kmax, args, step + 1, level + 1)
                solver.pop()

                print_separator('^', step, level)
                # the recursive calls have covered all remaining steps
                return

        step += 1

    print_message_with_propagation_profile("reached kmax: terminating\n", step, level)

def bmc(solver, kmin, kmax, args):
    """Print the BMC banner, initialize the solver, and start model checking.

    solver -- object providing assert_this() and prove(); it is handed on
              to branching_bmc() unchanged
    kmin   -- lower step bound, printed in the banner and passed through
    kmax   -- upper step bound, printed in the banner and passed through
    args   -- parsed command-line arguments; only use_bitr is read here,
              the rest is passed on to branching_bmc()
    """
    banner = f"bitme bounded model checking: -kmin {kmin} -kmax {kmax}\n"

    print_separator('-')
    print_message(banner)
    print_separator('-')

    # everything below starts out at step 0 and level 0
    first_step = 0
    top_level = 0

    # initialize all states
    solver.assert_this(Init.inits.values(), first_step)

    print_message("initializing", first_step, top_level)
    solver.prove()

    if args.use_bitr:
        fork(kmin, kmax)

    branching_bmc(solver, kmin, kmax, args, first_step, top_level)

import sys

def try_rotor():
    """Run rotor instead of bitme if the first command-line argument is --rotor.

    Does nothing (and returns None) unless the rotor library was loaded and
    sys.argv[1] equals '--rotor'. Otherwise rotor's main() is called with the
    remaining arguments and the process exits with status 0 afterwards, so
    this function does not return in that case.
    """
    if not (is_rotor_present and len(sys.argv) > 1 and sys.argv[1] == '--rotor'):
        return

    # just run rotor: remove --rotor but keep all other arguments
    argv = [sys.argv[0]] + sys.argv[2:]
    encoded_argv = [arg.encode('utf-8') for arg in argv]

    # C requires argv[argc] to be a null pointer, hence one extra slot
    # holding NULL; argc itself still counts only the real arguments
    c_argv = (ctypes.c_char_p * (len(encoded_argv) + 1))(*encoded_argv, None)

    rotor.main.argtypes = ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)
    rotor.main(len(encoded_argv), c_argv)

    # sys.exit instead of the builtin exit, which is injected by the site
    # module for interactive use and is missing when site is disabled
    sys.exit(0)

import argparse

def main():
    """Parse the bitme command line, load the BTOR2 model, and run BMC.

    try_rotor() is called first, before any bitme argument is parsed.
    After parsing, the module-level switches LAMBDAS, UNROLL, and PROPAGATE
    as well as the btor2 array settings are configured, and the model file
    is parsed (and written to the optional output file argument).

    Model checking is only performed if -kmin, -kmax, or -analyzor is given:

    - without --use-Z3 and --use-bitwuzla, a Bitme_Solver is used, provided
      that Variable.bvdd_input is non-empty; BVDDs are enabled by --use-BVDD
      or whenever --use-CFLOBVDD is absent, CFLOBVDDs by --use-CFLOBVDD
      followed by up to three optional integers: fork level, swap level,
      and level;
    - otherwise each requested SMT solver that is present is run in turn.
    """
    try_rotor()

    parser = argparse.ArgumentParser(prog='bitme',
        description="bitme is a bounded model checker for BTOR2 models, see github.com/cksystemsteaching/selfie for more details.",
        epilog="bitme is designed to work with BTOR2 models generated by rotor for modeling RISC-V machines and RISC-V code.")

    parser.add_argument('modelfile', type=argparse.FileType('r'))
    parser.add_argument('outputfile', nargs='?', type=argparse.FileType('w', encoding='UTF-8'))

    parser.add_argument('-analyzor', action='store_true')

    parser.add_argument('--use-bitr', action='store_true')

    parser.add_argument('--use-Z3', action='store_true')
    parser.add_argument('--use-bitwuzla', action='store_true')

    parser.add_argument('--use-BVDD', action='store_true')
    parser.add_argument('--use-CFLOBVDD', nargs='*', type=int)

    parser.add_argument('--no-reduction', action='store_true')

    parser.add_argument('--substitute', action='store_true')

    parser.add_argument('--unroll', action='store_true')
    parser.add_argument('-propagate', nargs=1, type=int)

    parser.add_argument('-array', nargs=1, type=int)
    parser.add_argument('--recursive-array', action='store_true')

    parser.add_argument('-kmin', nargs=1, type=int)
    parser.add_argument('-kmax', nargs=1, type=int)

    parser.add_argument('--print-pc', action='store_true') # only for rotor models
    parser.add_argument('--check-termination', action='store_true')
    parser.add_argument('--unconstraining-bad', action='store_true')
    parser.add_argument('--print-transition', action='store_true')
    parser.add_argument('--branching', action='store_true') # only for rotor models

    args = parser.parse_args()

    global LAMBDAS

    LAMBDAS = not args.substitute

    global UNROLL

    UNROLL = args.unroll

    global PROPAGATE

    # a negative -propagate value is treated like an absent option
    PROPAGATE = args.propagate[0] if args.propagate and args.propagate[0] >= 0 else None

    btor2.Array.ARRAY_SIZE_BOUND = args.array[0] if args.array else 0
    btor2.Read.READ_ARRAY_ITERATIVELY = not args.recursive_array

    print_separator('#')

    are_there_state_transitions = ValuesParser().parse_btor2(args.modelfile, args.outputfile)

    # args.kmin and args.kmax are one-element lists when given,
    # so even an explicit bound of 0 enables model checking
    if args.kmin or args.kmax or args.analyzor:
        kmin = args.kmin[0] if args.kmin else 0
        kmax = args.kmax[0] if args.kmax else 0

        if are_there_state_transitions:
            # make sure that kmax is not below kmin
            kmax = max(kmin, kmax)
        else:
            # no state transitions: only step 0 is checked
            kmin = kmax = 0

        z3_solver = None
        bitwuzla_solver = None

        if z3interface.is_Z3_present:
            z3_solver = z3interface.Z3_Solver(print_message, LAMBDAS, UNROLL)
        if bitwuzlainterface.is_bitwuzla_present:
            bitwuzla_solver = bitwuzlainterface.Bitwuzla_Solver(print_message, LAMBDAS, UNROLL)

        # BVDDs are the default whenever --use-CFLOBVDD is absent
        if args.use_BVDD or args.use_CFLOBVDD is None:
            Values.BVDD = True

            Values.BVDD_number_of_inputs = len(Variable.bvdd_input) if Variable.bvdd_input else 1

            # reversing order of input variables
            # Values.BVDD_input = dict([(2**level - 1 - index, var_line)
            #     for index, var_line in Variable.bvdd_input.items()])
            Values.BVDD_input = Variable.bvdd_input
            # inverse mapping of BVDD_input
            Values.BVDD_index = {var_line: index
                for index, var_line in Values.BVDD_input.items()}

            print_separator('-')
            print(f"BVDD configuration: {Values.BVDD_number_of_inputs} input bytes")

        if args.use_CFLOBVDD is not None:
            Values.CFLOBVDD = True

            # number of inputs rounded up to the next power of two
            Values.CFLOBVDD_number_of_inputs = 2**ceil(log2(len(Variable.bvdd_input))) if Variable.bvdd_input else 1

            level = ceil(log2(Values.CFLOBVDD_number_of_inputs))

            # optional positional integers: fork level, swap level, level
            Values.CFLOBVDD_fork_level = args.use_CFLOBVDD[0] if len(args.use_CFLOBVDD) > 0 else 0
            Values.CFLOBVDD_swap_level = args.use_CFLOBVDD[1] if len(args.use_CFLOBVDD) > 1 else level
            Values.CFLOBVDD_level = args.use_CFLOBVDD[2] if len(args.use_CFLOBVDD) > 2 else Values.CFLOBVDD_swap_level
            assert 0 <= Values.CFLOBVDD_fork_level <= Values.CFLOBVDD_level, \
                f"invalid fork level {Values.CFLOBVDD_fork_level} for level {Values.CFLOBVDD_level}"
            assert 0 <= Values.CFLOBVDD_swap_level <= Values.CFLOBVDD_level, \
                f"invalid swap level {Values.CFLOBVDD_swap_level} for level {Values.CFLOBVDD_level}"

            # reversing order of input variables
            # Values.CFLOBVDD_input = dict([(2**level - 1 - index, var_line)
            #     for index, var_line in Variable.bvdd_input.items()])
            Values.CFLOBVDD_input = Variable.bvdd_input
            # inverse mapping of CFLOBVDD_input
            Values.CFLOBVDD_index = {var_line: index
                for index, var_line in Values.CFLOBVDD_input.items()}

            print_separator('-')
            print(f"CFLOBVDD configuration: {Values.CFLOBVDD_number_of_inputs} input bytes " +
                f"@ level {Values.CFLOBVDD_level}, swap level {Values.CFLOBVDD_swap_level}, and fork level {Values.CFLOBVDD_fork_level}")

        CFLOBVDD.CFLOBVDD.REDUCE = not args.no_reduction

        bitme_solver = Bitme_Solver(z3_solver, bitwuzla_solver)

        if not args.use_Z3 and not args.use_bitwuzla:
            if Variable.bvdd_input:
                bmc(bitme_solver, kmin, kmax, args)

                print_separator('-')
                # the BVDD profile is printed for both representations,
                # but only once even if both are enabled at the same time
                if Values.BVDD or Values.CFLOBVDD:
                    BVDD.BVDD.print_profile()
                if Values.CFLOBVDD:
                    CFLOBVDD.CFLOBVDD.print_profile()
            else:
                print_separator('-')
                print("model input is unmapped, consider increasing -array")
        else:
            # a requested SMT solver that is not present is skipped
            if args.use_Z3 and z3interface.is_Z3_present:
                bmc(z3_solver, kmin, kmax, args)
            if args.use_bitwuzla and bitwuzlainterface.is_bitwuzla_present:
                bmc(bitwuzla_solver, kmin, kmax, args)

    print_separator('#')

# run bitme only when this file is executed as a script,
# not when it is imported as a module
if __name__ == '__main__':
    main()
